{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "520bca22",
   "metadata": {},
   "source": [
    "# 1-4, Time Series Modeling Workflow Example"
   ]
  },
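  {
   "cell_type": "markdown",
   "id": "914d0791",
   "metadata": {},
   "source": [
    "The COVID-19 pandemic that broke out in 2020 affected people's lives around the world in many ways.\n",
    "\n",
    "Some of us were hit in our income, some in our relationships, some in our mental health, and some in our weight.\n",
    "\n",
    "This article uses China's epidemic data from the period up to early March 2020 to build an RNN time series model that predicts when the COVID-19 epidemic in China will end."
   ]
  },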
  {
   "cell_type": "markdown",
   "id": "7fd11a67",
   "metadata": {},
   "source": [
    "![](./data/疫情前后对比.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5054d82",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install torch==1.10.0\n",
    "!pip install pytorch_lightning==1.6.5 \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "306f3ae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import pytorch_lightning\n",
    "print(\"torch.__version__ = \", torch.__version__)\n",
    "print(\"pytorch_lightning.__version__ = \", pytorch_lightning.__version__) \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88732a82",
   "metadata": {},
   "source": [
    "```\n",
    "torch.__version__ =  1.10.0\n",
    "pytorch_lightning.__version__ =  1.6.5\n",
    "```"
   ]
  },
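  {
   "cell_type": "markdown",
   "id": "73b65b4b",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "<font color=\"red\">\n",
    " \n",
    "Reply with the keyword **pytorch** to the WeChat official account **算法美食屋** to get the Baidu Netdisk download link for this project's source code and datasets.\n",
    "    \n",
    "</font> \n"
   ]
  },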
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19e680ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import datetime\n",
    "import torchkeras\n",
    "\n",
    "# Print a separator bar followed by the current time\n",
    "def printbar():\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n",
    "\n",
    "# On macOS, running PyTorch and matplotlib together in Jupyter requires setting this environment variable\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\" \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "603a13af",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dbb69fd",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4d46f9f9",
   "metadata": {},
   "source": [
    "## 1. Prepare the data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a5acccb",
   "metadata": {},
   "source": [
    "The dataset comes from tushare. The method for obtaining it follows this article:\n",
    "\n",
    "https://zhuanlan.zhihu.com/p/109556102\n",
    "\n",
    "![](./data/1-4-新增人数.png)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "738b171a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd \n",
    "import matplotlib.pyplot as plt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "903d0408",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "df = pd.read_csv(\"./eat_pytorch_datasets/covid-19.csv\",sep = \"\\t\")\n",
    "df.plot(x = \"date\",y = [\"confirmed_num\",\"cured_num\",\"dead_num\"],figsize=(10,6))\n",
    "plt.xticks(rotation=60);\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0636bedf",
   "metadata": {},
   "source": [
    "![](./data/1-4-累积曲线.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5a55006",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfdata = df.set_index(\"date\")\n",
    "dfdiff = dfdata.diff(periods=1).dropna()\n",
    "dfdiff = dfdiff.reset_index(\"date\")\n",
    "\n",
    "dfdiff.plot(x = \"date\",y = [\"confirmed_num\",\"cured_num\",\"dead_num\"],figsize=(10,6))\n",
    "plt.xticks(rotation=60)\n",
    "dfdiff = dfdiff.drop(\"date\",axis = 1).astype(\"float32\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ebcb194",
   "metadata": {},
   "source": [
    "![](./data/1-4-新增曲线.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cb89487",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfdiff.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77134e8f",
   "metadata": {},
   "source": [
    "![](./data/1-4-dfdiff.png)"
   ]
  },
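  {
   "cell_type": "markdown",
   "id": "15dc30a0",
   "metadata": {},
   "source": [
    "Next, we implement a custom time series dataset by subclassing `torch.utils.data.Dataset`.\n",
    "\n",
    "`torch.utils.data.Dataset` is an abstract class. To load custom data, subclass it and override two methods:\n",
    "\n",
    "* `__len__`: makes `len(dataset)` return the size of the whole dataset.\n",
    "* `__getitem__`: supports indexing, so that `dataset[i]` returns the i-th sample.\n",
    "\n",
    "If these two methods are not overridden, calling them raises an error.\n"
   ]
  },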
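  {
   "cell_type": "markdown",
   "id": "a1f3c5e7",
   "metadata": {},
   "source": [
    "As a dependency-free sketch of this two-method protocol, the toy class below implements `__len__` and `__getitem__` over a plain Python list. The name `WindowDataset` and the sample numbers are illustrative and not part of this project; a real dataset would subclass `torch.utils.data.Dataset` and return tensors.\n",
    "\n",
    "```python\n",
    "class WindowDataset:\n",
    "    '''Toy sliding-window dataset: window_size values in, the next value out.'''\n",
    "\n",
    "    def __init__(self, series, window_size):\n",
    "        self.series = list(series)\n",
    "        self.window_size = window_size\n",
    "\n",
    "    def __len__(self):\n",
    "        # One sample for every position that still has a following value to predict\n",
    "        return len(self.series) - self.window_size\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        feature = self.series[i:i + self.window_size]\n",
    "        label = self.series[i + self.window_size]\n",
    "        return feature, label\n",
    "\n",
    "\n",
    "toy_ds = WindowDataset([10, 20, 30, 40, 50], window_size=3)\n",
    "print(len(toy_ds))   # 2\n",
    "print(toy_ds[0])     # ([10, 20, 30], 40)\n",
    "print(toy_ds[1])     # ([20, 30, 40], 50)\n",
    "```\n",
    "\n",
    "The length is the series length minus the window size because the last window must still have a following value to serve as its label."
   ]
  },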
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d654bf5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "from torch.utils.data import Dataset,DataLoader,TensorDataset\n",
    "\n",
    "\n",
    "# Use the 8 days before a given day as the input window to predict that day\n",
    "WINDOW_SIZE = 8\n",
    "\n",
    "class Covid19Dataset(Dataset):\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(dfdiff) - WINDOW_SIZE\n",
    "    \n",
    "    def __getitem__(self,i):\n",
    "        # .loc slicing is label-based and inclusive, so this selects WINDOW_SIZE rows\n",
    "        x = dfdiff.loc[i:i+WINDOW_SIZE-1,:]\n",
    "        feature = torch.tensor(x.values)\n",
    "        y = dfdiff.loc[i+WINDOW_SIZE,:]\n",
    "        label = torch.tensor(y.values)\n",
    "        return (feature,label)\n",
    "    \n",
    "ds_train = Covid19Dataset()\n",
    "\n",
    "# The dataset is small, so all training samples fit in a single batch, which speeds up training\n",
    "dl_train = DataLoader(ds_train,batch_size = 38)\n",
    "\n",
    "# Fetch one batch; features is reused later as sample input for summary\n",
    "for features,labels in dl_train:\n",
    "    break \n",
    "    \n",
    "# Reuse dl_train as the validation set\n",
    "dl_val = dl_train\n"
   ]
  },
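  {
   "cell_type": "markdown",
   "id": "3f1410eb",
   "metadata": {},
   "source": [
    "## 2. Define the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b39abea",
   "metadata": {},
   "source": [
    "There are three common ways to build a model in PyTorch: stacking layers in order with `nn.Sequential`, subclassing `nn.Module` to write a custom model, and subclassing `nn.Module` while using model containers to wrap groups of layers.\n",
    "\n",
    "Here we use the second approach.\n",
    "\n"
   ]
  },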
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42466f8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn \n",
    "import importlib \n",
    "import torchkeras \n",
    "\n",
    "torch.random.seed()\n",
    "\n",
    "class Block(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Block,self).__init__()\n",
    "    \n",
    "    def forward(self,x,x_input):\n",
    "        # Treat x as a relative change applied to the last day of the window; clip at 0\n",
    "        x_out = torch.max((1+x)*x_input[:,-1,:],torch.tensor(0.0))\n",
    "        return x_out\n",
    "    \n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        # 5-layer LSTM (num_layers = 5)\n",
    "        self.lstm = nn.LSTM(input_size = 3,hidden_size = 3,num_layers = 5,batch_first = True)\n",
    "        self.linear = nn.Linear(3,3)\n",
    "        self.block = Block()\n",
    "        \n",
    "    def forward(self,x_input):\n",
    "        # Keep only the LSTM output at the last time step\n",
    "        x = self.lstm(x_input)[0][:,-1,:]\n",
    "        x = self.linear(x)\n",
    "        y = self.block(x,x_input)\n",
    "        return y\n",
    "        \n",
    "net = Net()\n",
    "print(net)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b05e00cd",
   "metadata": {},
   "source": [
    "```\n",
    "Net(\n",
    "  (lstm): LSTM(3, 3, num_layers=5, batch_first=True)\n",
    "  (linear): Linear(in_features=3, out_features=3, bias=True)\n",
    "  (block): Block()\n",
    ")\n",
    "```"
   ]
  },
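  {
   "cell_type": "markdown",
   "id": "b2d4f6a8",
   "metadata": {},
   "source": [
    "To make the role of `Block` concrete: it reads the linear layer's output as a relative change, applies it to the last day of the input window, and clips the result at zero so that predicted counts never go negative. Below is a plain-Python sketch of that arithmetic for a single value; the helper name `next_value` and the numbers are illustrative only.\n",
    "\n",
    "```python\n",
    "def next_value(rate, last_value):\n",
    "    # Same formula as Block.forward for one element: max((1 + rate) * last_value, 0)\n",
    "    return max((1 + rate) * last_value, 0.0)\n",
    "\n",
    "\n",
    "print(next_value(0.5, 100.0))    # 150.0 : 50% growth over the last day\n",
    "print(next_value(-0.25, 100.0))  # 75.0  : 25% decline\n",
    "print(next_value(-2.0, 100.0))   # 0.0   : a negative result is clipped to zero\n",
    "```\n",
    "\n",
    "Predicting a relative change instead of an absolute count keeps the network output small even when the daily counts are in the thousands."
   ]
  },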
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e02fe265",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchkeras import summary\n",
    "summary(net,input_data=features);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "730073d3",
   "metadata": {},
   "source": [
    "```\n",
    "--------------------------------------------------------------------------\n",
    "Layer (type)                            Output Shape              Param #\n",
    "==========================================================================\n",
    "LSTM-1                                    [-1, 8, 3]                  480\n",
    "Linear-2                                     [-1, 3]                   12\n",
    "Block-3                                      [-1, 3]                    0\n",
    "==========================================================================\n",
    "Total params: 492\n",
    "Trainable params: 492\n",
    "Non-trainable params: 0\n",
    "--------------------------------------------------------------------------\n",
    "Input size (MB): 0.000069\n",
    "Forward/backward pass size (MB): 0.000229\n",
    "Params size (MB): 0.001877\n",
    "Estimated Total Size (MB): 0.002174\n",
    "--------------------------------------------------------------------------\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4649f9f5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
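  {
   "cell_type": "markdown",
   "id": "4d8d90bc",
   "metadata": {},
   "source": [
    "## 3. Train the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39722b11",
   "metadata": {},
   "source": [
    "Training in PyTorch usually requires writing a custom training loop, and the coding style of such loops varies from person to person.\n",
    "\n",
    "There are three typical styles: script-style, function-style, and class-style training loops.\n",
    "\n",
    "Here we show a class-style training loop built on the pytorch_lightning library.\n",
    "\n",
    "This training loop is also the core code of the LightModel class in the torchkeras library.\n",
    "\n",
    "torchkeras details:  https://github.com/lyhue1991/torchkeras \n",
    "\n",
    "Note: recurrent neural networks are hard to tune. Expect to try several different learning rates before getting good results.\n"
   ]
  },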
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93701678",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f4d1cb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "import pytorch_lightning as pl\n",
    "import datetime\n",
    "import sys \n",
    "import numpy as np\n",
    "import pandas as pd \n",
    "from copy import deepcopy\n",
    "\n",
    "class LightModel(pl.LightningModule):\n",
    "    def __init__(self,net,loss_fn,metrics_dict=None,optimizer=None,lr_scheduler=None):\n",
    "        super().__init__()\n",
    "        self.net = net\n",
    "        self.history = {}\n",
    "        \n",
    "        self.train_metrics = nn.ModuleDict(metrics_dict)\n",
    "        self.val_metrics = deepcopy(self.train_metrics)\n",
    "        self.test_metrics = deepcopy(self.train_metrics)\n",
    "        \n",
    "        self.loss_fn = loss_fn\n",
    "        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(self.parameters(), lr=1e-2)\n",
    "        self.lr_scheduler = lr_scheduler \n",
    "        \n",
    "        for p in [\"net\",\"loss_fn\",\"metrics_dict\",\"optimizer\",\"lr_scheduler\"]:\n",
    "            self.save_hyperparameters(p)\n",
    "        \n",
    "    def forward(self,x):\n",
    "        if self.net:\n",
    "            return self.net.forward(x)\n",
    "        else:\n",
    "            raise NotImplementedError\n",
    "            \n",
    "    def shared_step(self,batch,batch_idx):\n",
    "        x, y = batch\n",
    "        preds = self(x)\n",
    "        loss = self.loss_fn(preds,y)\n",
    "        return {'loss': loss, 'preds': preds.detach(), 'y': y.detach()}\n",
    "    \n",
    "    def configure_optimizers(self):\n",
    "        if self.lr_scheduler is None:\n",
    "            return self.optimizer\n",
    "        return {\"optimizer\":self.optimizer,\"lr_scheduler\":self.lr_scheduler}\n",
    "    \n",
    "    def training_step(self, batch, batch_idx):\n",
    "        return self.shared_step(batch,batch_idx)\n",
    "    \n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        return self.shared_step(batch,batch_idx)\n",
    "    \n",
    "    def test_step(self, batch, batch_idx):\n",
    "        return self.shared_step(batch,batch_idx)\n",
    "    \n",
    "    def predict_step(self, batch, batch_idx):\n",
    "        if isinstance(batch,list) and len(batch)==2:\n",
    "            return self(batch[0])\n",
    "        else:\n",
    "            return self(batch)\n",
    "        \n",
    "    def shared_step_end(self,outputs,stage):\n",
    "        metrics = self.train_metrics if stage==\"train\" else (\n",
    "            self.val_metrics if stage==\"val\" else self.test_metrics)\n",
    "        for name in metrics:\n",
    "            step_metric = metrics[name](outputs['preds'], outputs['y']).item()\n",
    "            if stage==\"train\":\n",
    "                self.log(name,step_metric,prog_bar=True)\n",
    "        return outputs[\"loss\"].mean()\n",
    "        \n",
    "    def training_step_end(self, outputs):\n",
    "        return {'loss':self.shared_step_end(outputs,\"train\")}\n",
    "            \n",
    "    def validation_step_end(self, outputs):\n",
    "        return {'val_loss':self.shared_step_end(outputs,\"val\")}\n",
    "            \n",
    "    def test_step_end(self, outputs):\n",
    "        return {'test_loss':self.shared_step_end(outputs,\"test\")}\n",
    "            \n",
    "    def shared_epoch_end(self,outputs,stage=\"train\"):\n",
    "        metrics = self.train_metrics if stage==\"train\" else (\n",
    "            self.val_metrics if stage==\"val\" else self.test_metrics)\n",
    "        \n",
    "        epoch = self.trainer.current_epoch\n",
    "        stage_loss = torch.mean(torch.tensor([t[(stage+\"_loss\").replace('train_','')] for t in outputs])).item()\n",
    "        dic = {\"epoch\":epoch,stage+\"_loss\":stage_loss}\n",
    "        \n",
    "        for name in metrics:\n",
    "            epoch_metric = metrics[name].compute().item() \n",
    "            metrics[name].reset()\n",
    "            dic[stage+\"_\"+name] = epoch_metric \n",
    "        if stage!='test':\n",
    "            self.history[epoch] = dict(self.history.get(epoch,{}),**dic)    \n",
    "        return dic \n",
    "    \n",
    "    def training_epoch_end(self, outputs):\n",
    "        dic = self.shared_epoch_end(outputs,stage=\"train\")\n",
    "        self.print(dic)\n",
    "        dic.pop(\"epoch\",None)\n",
    "        self.log_dict(dic, logger=True)\n",
    "\n",
    "    def validation_epoch_end(self, outputs):\n",
    "        dic = self.shared_epoch_end(outputs,stage=\"val\")\n",
    "        self.print_bar()\n",
    "        self.print(dic)\n",
    "        dic.pop(\"epoch\",None)\n",
    "        self.log_dict(dic, logger=True)\n",
    "        \n",
    "        #log when reach best score\n",
    "        ckpt_cb = self.trainer.checkpoint_callback\n",
    "        monitor = ckpt_cb.monitor \n",
    "        mode = ckpt_cb.mode \n",
    "        arr_scores = self.get_history()[monitor]\n",
    "        best_score_idx = np.argmax(arr_scores) if mode==\"max\" else np.argmin(arr_scores)\n",
    "        if best_score_idx==len(arr_scores)-1:   \n",
    "            self.print(\"<<<<<< reach best {0} : {1} >>>>>>\".format(monitor,\n",
    "                 arr_scores[best_score_idx]),file=sys.stderr)\n",
    "            \n",
    "    \n",
    "    def test_epoch_end(self, outputs):\n",
    "        dic = self.shared_epoch_end(outputs,stage=\"test\")\n",
    "        dic.pop(\"epoch\",None)\n",
    "        self.print(dic)\n",
    "        self.log_dict(dic, logger=True)\n",
    "        \n",
    "    def get_history(self):\n",
    "        return pd.DataFrame(self.history.values()) \n",
    "    \n",
    "    def print_bar(self): \n",
    "        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "        self.print(\"\\n\"+\"=\"*80 + \"%s\"%nowtime)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bf9b113",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchmetrics.regression import MeanAbsolutePercentageError\n",
    "\n",
    "def mspe(y_pred,y_true):\n",
    "    err_percent = (y_true - y_pred)**2/(torch.max(y_true**2,torch.tensor(1e-7)))\n",
    "    return torch.mean(err_percent)\n",
    "\n",
    "\n",
    "net = Net() \n",
    "loss_fn = mspe\n",
    "metric_dict = {\"mape\":MeanAbsolutePercentageError()}\n",
    "\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=0.03)\n",
    "lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.0001)\n",
    "\n",
    "model = LightModel(net,\n",
    "       loss_fn = loss_fn,\n",
    "       metrics_dict= metric_dict,\n",
    "       optimizer = optimizer,\n",
    "       lr_scheduler = lr_scheduler)       \n"
   ]
  },
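  {
   "cell_type": "markdown",
   "id": "c3e5a7b9",
   "metadata": {},
   "source": [
    "The loss `mspe` is a mean squared percentage error: each squared error is divided by the squared true value, which is floored at 1e-7 to avoid division by zero. Below is a plain-Python sketch of the same formula on illustrative numbers; the function name `mspe_py` is hypothetical.\n",
    "\n",
    "```python\n",
    "def mspe_py(y_pred, y_true, eps=1e-7):\n",
    "    # Mean of (y_true - y_pred)^2 / max(y_true^2, eps)\n",
    "    terms = [(t - p) ** 2 / max(t ** 2, eps) for p, t in zip(y_pred, y_true)]\n",
    "    return sum(terms) / len(terms)\n",
    "\n",
    "\n",
    "# Relative errors of 10% and 50% with respect to the true values\n",
    "toy_loss = mspe_py([90.0, 30.0], [100.0, 20.0])\n",
    "print(toy_loss)  # approximately 0.13 = (0.1**2 + 0.5**2) / 2\n",
    "```\n",
    "\n",
    "Because the error is relative, missing by 10 cases costs far more when the true count is 20 than when it is 100."
   ]
  },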
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8185438",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl     \n",
    "\n",
    "#1. Set up the callbacks\n",
    "\n",
    "model_ckpt = pl.callbacks.ModelCheckpoint(\n",
    "    monitor='val_mape',\n",
    "    save_top_k=1,\n",
    "    mode='min'\n",
    ")\n",
    "\n",
    "early_stopping = pl.callbacks.EarlyStopping(monitor = 'val_mape',\n",
    "                           patience=3,\n",
    "                           mode = 'min'\n",
    "                          )\n",
    "\n",
    "#2. Configure the trainer\n",
    "# gpus=0 trains on the CPU, gpus=1 on 1 GPU, gpus=2 on 2 GPUs, gpus=-1 on all available GPUs,\n",
    "# gpus=[0,1] selects GPUs 0 and 1, gpus=\"0,1,2,3\" selects GPUs 0, 1, 2 and 3\n",
    "# tpu_cores=1 trains on 1 TPU core\n",
    "trainer = pl.Trainer(logger=True,\n",
    "                     min_epochs=3,max_epochs=30,\n",
    "                     gpus=0,\n",
    "                     callbacks = [model_ckpt,early_stopping],\n",
    "                     enable_progress_bar = True) \n",
    "\n",
    "\n",
    "#3. Start the training loop\n",
    "trainer.fit(model,dl_train,dl_val)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cc46c53",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfhistory = model.get_history() \n",
    "dfhistory "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7bb0fcf",
   "metadata": {},
   "source": [
    "```\n",
    "epoch\tval_loss\tval_mape\ttrain_loss\ttrain_mape\n",
    "0\t0\t5.974455\t0.661542\t6.936645\t0.737031\n",
    "1\t1\t5.086240\t0.590996\t5.974455\t0.661542\n",
    "2\t2\t4.237600\t0.524000\t5.086240\t0.590996\n",
    "3\t3\t3.408828\t0.463179\t4.237600\t0.524000\n",
    "4\t4\t2.614143\t0.422679\t3.408828\t0.463179\n",
    "5\t5\t1.896354\t0.413116\t2.614143\t0.422679\n",
    "6\t6\t1.304007\t0.437700\t1.896354\t0.413116\n",
    "7\t7\t0.866170\t0.474878\t1.304007\t0.437700\n",
    "8\t8\t0.585183\t0.517995\t0.866170\t0.474878\n",
    "\t\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "792e4f1a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
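  {
   "cell_type": "markdown",
   "id": "ed66a496",
   "metadata": {},
   "source": [
    "## 4. Evaluate the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b51cbde5",
   "metadata": {},
   "source": [
    "Evaluating a model normally requires a validation or test set. Because this example has very little data, the training set is reused for validation, and we only visualize how the loss and the metric evolve during training."
   ]
  },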
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb5f075f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_metric(dfhistory, metric):\n",
    "    train_metrics = dfhistory[\"train_\"+metric]\n",
    "    val_metrics = dfhistory['val_'+metric]\n",
    "    epochs = range(1, len(train_metrics) + 1)\n",
    "    plt.plot(epochs, train_metrics, 'bo--')\n",
    "    plt.plot(epochs, val_metrics, 'ro-')\n",
    "    plt.title('Training and validation '+ metric)\n",
    "    plt.xlabel(\"Epochs\")\n",
    "    plt.ylabel(metric)\n",
    "    plt.legend([\"train_\"+metric, 'val_'+metric])\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46963c00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate using the best checkpoint\n",
    "trainer.test(ckpt_path='best', dataloaders=dl_val,verbose = False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c6a30ce",
   "metadata": {},
   "source": [
    "```\n",
    "{'test_loss': 1.8963541984558105, 'test_mape': 0.4131162464618683}\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f1a46fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"loss\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "770aa7ec",
   "metadata": {},
   "source": [
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h48rhp9jpqj20ej0acwer.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fce9cf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"mape\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6f575f1",
   "metadata": {},
   "source": [
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h48rj15ctvj20f70ag3yv.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdc717dd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "2687ac02",
   "metadata": {},
   "source": [
    "## 5. Use the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59e0db6a",
   "metadata": {},
   "source": [
    "Here we use the model to predict when the epidemic will end, that is, when the number of newly confirmed cases drops to 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4164d6eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dfresult holds the existing data plus the epidemic data predicted from here on\n",
    "dfresult = dfdiff[[\"confirmed_num\",\"cured_num\",\"dead_num\"]].copy()\n",
    "dfresult.tail()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1b9f003",
   "metadata": {},
   "source": [
    "![](./data/1-4-日期3月10.png)"
   ]
  },
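  {
   "cell_type": "code",
   "execution_count": null,
   "id": "727b7d0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict the daily increments for the next 1000 days and append each prediction to dfresult\n",
    "# Note: each input is the last 38 days, whereas training used WINDOW_SIZE = 8-day windows; the LSTM accepts both lengths\n",
    "for i in range(1000):\n",
    "    arr_input = torch.unsqueeze(torch.from_numpy(dfresult.values[-38:,:]),dim=0)\n",
    "    arr_predict = model.forward(arr_input)\n",
    "\n",
    "    # Round the predicted counts down to whole numbers before appending\n",
    "    dfpredict = pd.DataFrame(torch.floor(arr_predict).data.numpy(),\n",
    "                columns = dfresult.columns)\n",
    "    dfresult = pd.concat([dfresult,dfpredict],ignore_index=True)"
   ]
  },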
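  {
   "cell_type": "markdown",
   "id": "d4f6b8ca",
   "metadata": {},
   "source": [
    "The loop above is an autoregressive rollout: each prediction is appended to the history and becomes part of the input for the next step. Below is a dependency-free sketch of that pattern, in which the hypothetical `toy_predict` (a fixed 50% daily decline, rounded down) stands in for the trained model.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def toy_predict(window):\n",
    "    # Stand-in for the model: the next value is half of the last one, rounded down\n",
    "    return math.floor(window[-1] * 0.5)\n",
    "\n",
    "\n",
    "history = [80, 40, 20]\n",
    "for _ in range(6):\n",
    "    window = history[-3:]            # the most recent values form the input\n",
    "    history.append(toy_predict(window))\n",
    "\n",
    "print(history)             # [80, 40, 20, 10, 5, 2, 1, 0, 0]\n",
    "print(history.index(0))    # 7 : first step at which the series reaches zero\n",
    "```\n",
    "\n",
    "Since every step feeds on earlier predictions, errors compound over the horizon, which is one reason long-range forecasts from such a rollout should be read with caution."
   ]
  },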
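  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d081b013",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfresult.query(\"confirmed_num==0\").head()\n",
    "\n",
    "# Newly confirmed cases drop to 0 from day 50. Day 45 corresponds to March 10, so that is 5 days later: new confirmed cases are predicted to reach 0 on March 15.\n",
    "# Note: this prediction is on the optimistic side."
   ]
  },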
  {
   "cell_type": "markdown",
   "id": "d954a157",
   "metadata": {},
   "source": [
    "![](./data/1-4-torch预测确诊.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad96908f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
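  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc5477ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfresult.query(\"cured_num==0\").head()\n",
    "\n",
    "# Newly cured cases drop to 0 from day 132. Day 45 corresponds to March 10, so that is about 3 months later: everyone is predicted to be cured around June 10.\n",
    "# Note: this prediction is on the pessimistic side and is also inconsistent: summing the predicted daily cured counts would exceed the cumulative number of confirmed cases."
   ]
  },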
  {
   "cell_type": "markdown",
   "id": "3f3c999a",
   "metadata": {},
   "source": [
    "![](./data/1-4-torch预测治愈.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da13d833",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ec4f9500",
   "metadata": {},
   "source": [
    "## 6. Save the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a78086b",
   "metadata": {},
   "source": [
    "The model is saved at the path given by `trainer.checkpoint_callback.best_model_path`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "239f6463",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(trainer.checkpoint_callback.best_model_path)\n",
    "print(trainer.checkpoint_callback.best_model_score)"
   ]
  },
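  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f5bef85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The checkpoint stores the weights together with the hyperparameters recorded by save_hyperparameters (including net),\n",
    "# so the model can be rebuilt with LightModel.load_from_checkpoint\n",
    "model_loaded = LightModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)"
   ]
  },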
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "748f2b69",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.predict(model_loaded,dataloaders=dl_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06790009",
   "metadata": {},
   "source": [
    "```\n",
    "[tensor([[1.4974e+03, 8.5825e+01, 2.3000e+01],\n",
    "         [1.8469e+03, 6.7295e+01, 4.6768e+01],\n",
    "         [2.0153e+03, 1.4337e+02, 4.3701e+01],\n",
    "         [2.3063e+03, 1.5312e+02, 4.9068e+01],\n",
    "         [2.7721e+03, 2.5358e+02, 4.9834e+01],\n",
    "         [2.6352e+03, 2.5455e+02, 5.5968e+01],\n",
    "         [2.2421e+03, 3.7744e+02, 5.5968e+01],\n",
    "         [2.4147e+03, 4.9740e+02, 6.5935e+01],\n",
    "         [1.8918e+03, 5.8420e+02, 6.8235e+01],\n",
    "         [2.1208e+03, 6.1638e+02, 7.4368e+01],\n",
    "         [1.7599e+03, 6.9733e+02, 8.2802e+01],\n",
    "         [1.4374e+03, 7.2561e+02, 7.4368e+01],\n",
    "         [1.0808e+04, 1.1421e+03, 1.9474e+02],\n",
    "         [2.8870e+03, 7.9193e+02, 9.9669e+00],\n",
    "         [1.8840e+03, 1.3391e+03, 1.0964e+02],\n",
    "         [1.4324e+03, 1.2903e+03, 1.0887e+02],\n",
    "         [1.4610e+03, 1.3898e+03, 8.0502e+01],\n",
    "         [1.3468e+03, 1.6658e+03, 7.5135e+01],\n",
    "         [1.2477e+03, 1.7789e+03, 1.0427e+02],\n",
    "         [2.7895e+02, 1.7353e+03, 8.7424e+01],\n",
    "         [6.3425e+02, 2.0573e+03, 9.0511e+01],\n",
    "         [5.8716e+02, 2.3365e+03, 8.3624e+01],\n",
    "         [4.6230e+02, 2.1747e+03, 7.4431e+01],\n",
    "         [1.5267e+02, 1.8011e+03, 1.1512e+02],\n",
    "         [3.6241e+02, 2.5261e+03, 5.4495e+01],\n",
    "         [2.8963e+02, 2.3633e+03, 3.9916e+01],\n",
    "         [3.0889e+02, 2.6834e+03, 2.2262e+01],\n",
    "         [2.3327e+02, 3.5342e+03, 3.3777e+01],\n",
    "         [3.0461e+02, 2.8151e+03, 3.6080e+01],\n",
    "         [4.0876e+02, 2.5594e+03, 2.6868e+01],\n",
    "         [1.4410e+02, 2.7683e+03, 3.2242e+01],\n",
    "         [8.9172e+01, 2.6756e+03, 2.3798e+01],\n",
    "         [8.4892e+01, 2.5877e+03, 2.9171e+01],\n",
    "         [9.9159e+01, 2.1360e+03, 2.3798e+01],\n",
    "         [1.0201e+02, 1.6403e+03, 2.3030e+01],\n",
    "         [7.0624e+01, 1.6373e+03, 2.1495e+01],\n",
    "         [3.1389e+01, 1.6208e+03, 2.0727e+01],\n",
    "         [2.8535e+01, 1.4978e+03, 1.6889e+01]])]\n",
    "```"
   ]
  },
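  {
   "cell_type": "markdown",
   "id": "39cb303a",
   "metadata": {},
   "source": [
    "**If this book has helped you and you would like to encourage the author, please give the project a star ⭐️ and share it with your friends 😊!** \n",
    "\n",
    "If you would like to discuss the content of this book further with the author, feel free to leave a message on the WeChat official account \"算法美食屋\". The author has limited time and energy and will reply as circumstances allow.\n",
    "\n",
    "You can also reply with the keyword **加群** (join group) in the official account to join the reader group and discuss with others.\n",
    "\n",
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)"
   ]
  }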
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,md",
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
